# This file reads in atmospheric fluxes from nuflux and propagates them through the Earth using nuSQuIDS.
# It saves the propagated fluxes to a .npz file so they can be interpolated later, and outputs a plot for verification. 
# adapted by A. Wen from the nuSQuIDS demo file.

# usage: python atm_flux_propagation_fromNuFlux.py H3a_SIBYLL23C
# H3a_SIBYLL23C is only an example; the full list of options can be found in the nuflux documentation

# imports 
################################################################################################################

# nuSQuIDS: https://github.com/arguelles/nuSQuIDS
import nuSQuIDS as nsq 
# nuflux: https://docs.icecube.aq/nuflux/main/index.html
import nuflux

import matplotlib.pyplot as plt
import numpy as np
import sys


# choose the desired flux model. For choices see here: https://docs.icecube.aq/nuflux/main/fluxes.html
################################################################################################################
# The flux model name is the single required command-line argument.
# Exit with a usage message (instead of an IndexError traceback) when it is missing.
if len(sys.argv) < 2:
    sys.exit('usage: python '+sys.argv[0]+' <nuflux model name>   (example: H3a_SIBYLL23C)')
flux_type = sys.argv[1]
#example: 'H3a_SIBYLL23C'
# flux is queried with flux.getFlux(particle, energy, cos(theta)) further down in this script.
flux = nuflux.makeFlux(flux_type)
print('Making fluxes with NuFlux option '+flux_type)

# nusquids setup
################################################################################################################
# Unit constants from nuSQuIDS; only the GeV factor is used here.
units = nsq.Const()

# Both values are handed straight to the nuSQUIDSAtm constructor at the end of this section.
interactions = True
neutrino_flavors = 3

# Energy grid, built with nsq.logspace between E_min and E_max.
E_nodes = 150
E_min, E_max = 50.0*units.GeV, 5000000*units.GeV
energy_nodes = nsq.logspace(E_min, E_max, E_nodes)

# cos(theta) grid, built with nsq.linspace between cth_min and cth_max.
# The node count gets its own name so that cth_nodes only ever refers to the node array.
cth_min, cth_max = -1.0, 0.1
n_cth_nodes = 100
cth_nodes = nsq.linspace(cth_min, cth_max, n_cth_nodes)

# Atmospheric nuSQuIDS object defined on the (cos(theta), energy) grid for
# neutrino_flavors flavors and NeutrinoType.both.
nsq_atm = nsq.nuSQUIDSAtm(
    cth_nodes,
    energy_nodes,
    neutrino_flavors,
    nsq.NeutrinoType.both,
    interactions,
)

# loop to populate the nuSQUIDSAtm object with the nuflux fluxes for both charges and all 3 flavors
################################################################################################################
# Initial flux table. Axes: cos(theta) node, energy node,
# neutrino (0) / antineutrino (1), flavor (0 = e, 1 = mu, 2 = tau).
flux_shape = (len(cth_nodes), len(energy_nodes), 2, neutrino_flavors)
AtmInitialFlux = np.zeros(flux_shape)

# (neutrino/antineutrino index, flavor index, nuflux particle) for every channel to fill.
flux_channels = (
    (0, 0, nuflux.NuE),       # nue
    (1, 0, nuflux.NuEBar),    # bar nue
    (0, 1, nuflux.NuMu),      # nu mu
    (1, 1, nuflux.NuMuBar),   # bar nu mu
    (0, 2, nuflux.NuTau),     # nu tau
    (1, 2, nuflux.NuTauBar),  # bar nu tau
)

for i_cth, costh in enumerate(nsq_atm.GetCosthRange()):
    for i_en, energy in enumerate(nsq_atm.GetERange()):
        # node energies are divided by 1e9 before being passed to nuflux (GeV values)
        energy_GeV = energy / 1e9
        for i_type, i_flavor, particle in flux_channels:
            AtmInitialFlux[i_cth, i_en, i_type, i_flavor] = flux.getFlux(particle, energy_GeV, costh)

# nusquids propagation
################################################################################################################   
# Load the flux table built above as the initial state, expressed in the flavor basis.
nsq_atm.Set_initial_state(AtmInitialFlux,nsq.Basis.flavor)

nsq_atm.Set_ProgressBar(True) # progress bar will be printed on terminal

# Relative and absolute error settings for the evolution.
# NOTE(review): 1e-15 is close to double-precision epsilon; presumably these are the
# numerical integrator tolerances -- confirm that such a tight value is intended.
nsq_atm.Set_rel_error(1.0e-15)
nsq_atm.Set_abs_error(1.0e-15)

# Run the propagation; the results are read back below with EvalFlavor.
nsq_atm.EvolveState()

# Energy and cos(theta) ranges as reported by the nuSQuIDS object.
# erange is used by the verification plot; cthrange is not referenced again
# in the part of the file shown here.
erange = nsq_atm.GetERange()
cthrange = nsq_atm.GetCosthRange()

# make a plot to show the effect of the propagation
################################################################################################################
# cos(theta) value at which every curve of the verification plot is evaluated
# (the same value that appears in the y-axis label).
plot_cth = -0.5
# x-axis values: the node energies divided by 1e9 (labelled as GeV on the plot).
erange_GeV = erange/1e9

# Propagated muon-flavor fluxes (flavor index 1). The last EvalFlavor argument is
# 1 for the antineutrino curve and 0 for the neutrino curve.
phi_mubar = [nsq_atm.EvalFlavor(1, plot_cth, en, 1) for en in erange]
phi_mu = [nsq_atm.EvalFlavor(1, plot_cth, en, 0) for en in erange]

# Unpropagated nuflux curves: (particle, color, legend label, line style),
# listed in the order in which they are drawn.
surface_curves = (
    (nuflux.NuE, "coral", r"$\nu_e$ flux at surface", '--'),
    (nuflux.NuMu, "cornflowerblue", r"$\nu_\mu$ flux at surface", '--'),
    (nuflux.NuTau, "lightgreen", r"$\nu_\tau$ flux at surface", '--'),
    (nuflux.NuEBar, "crimson", r"$\overline{\nu}_e$ flux at surface", '-.'),
    (nuflux.NuMuBar, "navy", r"$\overline{\nu}_\mu$ flux at surface", '-.'),
    (nuflux.NuTauBar, "darkgreen", r"$\overline{\nu}_\tau$ flux at surface", '-.'),
)
surface_fluxes = [
    [flux.getFlux(particle, en/1e9, plot_cth) for en in erange]
    for particle, _color, _label, _style in surface_curves
]

plt.figure(figsize = (6,6))

# Thick solid lines: fluxes after propagation.
plt.plot(erange_GeV, phi_mu, lw = 2.5, color = "cornflowerblue", label = r"$\nu_\mu$ flux at detector")
plt.plot(erange_GeV, phi_mubar, lw = 2.5, color = "darkblue", label = r"$\overline{\nu}_\mu$ flux at detector")

# Thin dashed / dash-dotted lines: the nuflux input fluxes.
for (_particle, curve_color, curve_label, curve_style), values in zip(surface_curves, surface_fluxes):
    plt.plot(erange_GeV, values, lw = 1.5, color = curve_color, label = curve_label, linestyle = curve_style)

plt.loglog()

plt.xlim(erange_GeV[0], erange_GeV[-1])
plt.xlabel(r"$E_\nu \: [{\rm GeV}]$")
plt.ylabel(r"$\phi^{atm}_\nu (E_\nu, \cos(\theta)=-0.5) \: [1/({\rm GeV}\:{\rm s}\:{\rm cm}^2\:{\rm sr})]$")

plt.grid()
plt.legend()

# The figure is written next to the script, named after the chosen flux model.
plt.savefig(flux_type+'.pdf')

# save the propagated fluxes as a function of energy and theta, they can be interpolated and eval'ed later: 
# (we care about muon flavor only)
################################################################################################################################
# Output file name, built once and reused for both the save and the message.
out_name = 'AtmMuFinalFlux'+flux_type+'.npz'

# Grid axes stored alongside the table; energies are divided by 1e9 (GeV values).
Enodes = np.asarray(energy_nodes) / 1e9
Cthnodes = np.asarray(cth_nodes)

# Propagated muon-flavor flux (flavor index 1).
# Axes: neutrino (0) / antineutrino (1), cos(theta) node, energy node.
AtmMuFinalFlux = np.zeros((2, len(cth_nodes), len(energy_nodes)))
for i_cth, costh in enumerate(nsq_atm.GetCosthRange()):
    for i_en, energy in enumerate(nsq_atm.GetERange()):
        for i_type in (0, 1):
            AtmMuFinalFlux[i_type, i_cth, i_en] = nsq_atm.EvalFlavor(1, costh, energy, i_type)

np.savez(out_name, cth_nodes=Cthnodes, energy_nodes=Enodes, flux=AtmMuFinalFlux)
print('Propagated fluxes saved to '+out_name)